{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
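    "# Using checkpoints in training jobs\n",
    "\n",
    "Training a machine learning model often takes a long time, because the model has to iterate over the training data many times before it converges. Training can be interrupted along the way for many reasons, such as machine failures, network problems, or bugs in the code. To avoid restarting from scratch after an interruption, developers usually save the model state to `checkpoint` files at regular intervals during training. After an interruption, the model parameters, optimizer state, training step, and other training state can be read back from the saved `checkpoint` file, and training resumes from there.\n",
    "\n",
    "This document describes how to use checkpoints in PAI training jobs."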
   ]
  },
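  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea can be illustrated with a minimal, framework-free sketch. It assumes the training state is a plain Python object that `pickle` can serialize; `save_state`, `load_state`, `train` and the `state.pkl` file name are hypothetical, and the loop body only stands in for real training. With PyTorch, `torch.save` and `torch.load` take the place of `pickle`.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import pickle\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def save_state(path, epoch, model, optimizer):\n",
    "    # Hypothetical helper: persist parameters, optimizer state and progress together.\n",
    "    with open(path, 'wb') as f:\n",
    "        pickle.dump({'epoch': epoch, 'model': model, 'optimizer': optimizer}, f)\n",
    "\n",
    "\n",
    "def load_state(path):\n",
    "    # Hypothetical helper: return the saved state, or None if no checkpoint exists yet.\n",
    "    if not os.path.exists(path):\n",
    "        return None\n",
    "    with open(path, 'rb') as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "def train(path, num_epochs):\n",
    "    state = load_state(path)\n",
    "    # Start from scratch, or resume from the epoch recorded in the checkpoint.\n",
    "    start_epoch = state['epoch'] if state else 0\n",
    "    model = state['model'] if state else {'w': 0.0}\n",
    "    optimizer = state['optimizer'] if state else {'step': 0}\n",
    "    for epoch in range(start_epoch, num_epochs):\n",
    "        model['w'] += 1.0  # stand-in for one epoch of real training\n",
    "        optimizer['step'] += 1\n",
    "        # Record epoch + 1 so that a restart begins with the next epoch.\n",
    "        save_state(path, epoch + 1, model, optimizer)\n",
    "    return model, optimizer\n",
    "\n",
    "\n",
    "ckpt = os.path.join(tempfile.mkdtemp(), 'state.pkl')\n",
    "train(ckpt, num_epochs=3)  # first run: epochs 0, 1, 2\n",
    "model, opt = train(ckpt, num_epochs=5)  # second run resumes: only epochs 3 and 4\n",
    "print(model, opt)  # {'w': 5.0} {'step': 5}\n",
    "```"
   ]
  },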
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Preparation\n",
    "\n",
    "First, install the PAI Python SDK to run this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m pip install --upgrade alipai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
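    "The SDK needs an AccessKey to access Alibaba Cloud services, plus the workspace and OSS Bucket to use. After the PAI Python SDK is installed, run the following command in a **terminal** and follow the prompts to configure the AccessKey, workspace, and related settings.\n",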
    "\n",
    "\n",
    "```shell\n",
    "\n",
    "# Run the following command in a terminal.\n",
    "\n",
    "python -m pai.toolkit.config\n",
    "\n",
    "```\n",
    "\n",
    "\n",
    "You can verify the current configuration with the following code."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
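    "## Saving and resuming training with checkpoints\n",
    "\n",
    "When a training job is submitted with the SDK's `pai.estimator.Estimator`, it mounts a path in the user's OSS Bucket to the job's `/ml/output/checkpoints` directory by default. Checkpoint files that the training code writes to this directory are therefore saved to OSS. After the job is submitted, the OSS path where the checkpoints are stored can be retrieved with `estimator.checkpoints_data()`.\n",
    "\n",
    "To reuse existing checkpoints, specify an OSS Bucket path with the `checkpoints_path` parameter. PAI mounts that path to the job's `/ml/output/checkpoints` directory, and the training code can resume training by reading the checkpoint files there.\n",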
    "\n",
    "\n",
    "\n",
    "```python\n",
    "\n",
    "from pai.estimator import Estimator\n",
    "\n",
    "\n",
    "# 1. Save checkpoints to the default checkpoints path\n",
    "est = Estimator(\n",
    "    image_uri=\"<TrainingImageUri>\",\n",
    "    command=\"python train.py\",\n",
    ")\n",
    "\n",
    "# By default, the training job mounts an OSS Bucket path to /ml/output/checkpoints.\n",
    "# The training code saves checkpoints by writing files to /ml/output/checkpoints.\n",
    "est.fit()\n",
    "\n",
    "# Print the OSS path where the job's checkpoints are stored\n",
    "print(est.checkpoints_data())\n",
    "\n",
    "# 2. Resume training from the checkpoints produced by another training job\n",
    "est_load = Estimator(\n",
    "    image_uri=\"<TrainingImageUri>\",\n",
    "    command=\"python train.py\",\n",
    "    # Use the checkpoints written by the previous training job.\n",
    "    checkpoints_path=est.checkpoints_data(),\n",
    ")\n",
    "\n",
    "# The training code loads the checkpoint from /ml/output/checkpoints\n",
    "est_load.fit()\n",
    "\n",
    "```\n",
    "\n",
    "\n",
    "\n"
   ]
  },
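  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside the job, the training code only needs to know where this directory is mounted. The job logs further below show that PAI also exposes it through the `PAI_OUTPUT_CHECKPOINTS` environment variable. The following is a minimal sketch, assuming checkpoint files carry a `.pt` suffix and that the most recently modified file is the one to resume from; `find_latest_checkpoint` is a hypothetical helper, not part of the PAI SDK.\n",
    "\n",
    "```python\n",
    "import glob\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "# Fallback used when the environment variable is not set (an assumption).\n",
    "DEFAULT_CHECKPOINT_DIR = '/ml/output/checkpoints/'\n",
    "\n",
    "\n",
    "def find_latest_checkpoint(pattern='*.pt'):\n",
    "    # Prefer the path exposed by PAI; fall back to the default mount point.\n",
    "    checkpoint_dir = os.environ.get('PAI_OUTPUT_CHECKPOINTS', DEFAULT_CHECKPOINT_DIR)\n",
    "    candidates = glob.glob(os.path.join(checkpoint_dir, pattern))\n",
    "    if not candidates:\n",
    "        return None  # nothing to resume from: train from scratch\n",
    "    # Treat the most recently modified file as the latest checkpoint.\n",
    "    return max(candidates, key=os.path.getmtime)\n",
    "\n",
    "\n",
    "# Local usage: simulate a mounted directory that holds two checkpoints.\n",
    "tmp_dir = tempfile.mkdtemp()\n",
    "os.environ['PAI_OUTPUT_CHECKPOINTS'] = tmp_dir\n",
    "for i, name in enumerate(['epoch_1.pt', 'epoch_2.pt']):\n",
    "    path = os.path.join(tmp_dir, name)\n",
    "    open(path, 'w').close()\n",
    "    os.utime(path, (1000 + i, 1000 + i))  # epoch_2.pt gets the newer mtime\n",
    "print(find_latest_checkpoint())  # a path ending in epoch_2.pt\n",
    "```"
   ]
  },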
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
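    "## Using checkpoints in PyTorch\n",
    "\n",
    "In PyTorch, the model parameters, optimizer state, training progress, and similar information are usually saved together as a dictionary with `torch.save`, forming a `checkpoint`. The saved `checkpoint` file can then be loaded with `torch.load`. PyTorch provides a tutorial on saving and loading checkpoints during training: [Saving and Loading a General Checkpoint in PyTorch](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html).\n",
    "\n",
    "Based on this PyTorch tutorial, we demonstrate how to use checkpoints in a PAI training job.\n"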
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
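    "The training code works as follows:\n",
    "\n",
    "1. Before training starts, it loads the checkpoint from the `/ml/output/checkpoints/` directory (located through the `PAI_OUTPUT_CHECKPOINTS` environment variable) to restore the model parameters, the optimizer state, and the training progress.\n",
    "\n",
    "2. It trains the model starting from the restored state and saves a checkpoint to `/ml/output/checkpoints/` at regular intervals during training (in this example, at the end of every epoch).\n",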
    "   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p train_src"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile train_src/train.py\n",
    "# Minimal PyTorch training script that resumes from and saves checkpoints.\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "CHECKPOINT_NAME = \"checkpoint.pt\"\n",
    "\n",
    "# Define a custom mock dataset\n",
    "class RandomDataset(Dataset):\n",
    "    def __init__(self, num_samples=1000):\n",
    "        self.num_samples = num_samples\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_samples\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        x = torch.randn(10)  # Generating random input tensor\n",
    "        y = torch.randint(0, 2, (1,)).item()  # Generating random target label (0 or 1)\n",
    "        return x, y\n",
    "\n",
    "\n",
    "# Define your model\n",
    "class MyModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyModel, self).__init__()\n",
    "        self.fc = nn.Linear(10, 2)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.fc(x)\n",
    "\n",
    "\n",
    "net = MyModel()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001)\n",
    "start_epoch = 0\n",
    "\n",
    "def load_checkpoint():\n",
    "    \"\"\"Load the checkpoint, if one exists, and restore the training state.\"\"\"\n",
    "    global start_epoch\n",
    "    checkpoint_dir = os.environ.get(\"PAI_OUTPUT_CHECKPOINTS\")\n",
    "    if not checkpoint_dir:\n",
    "        return\n",
    "    checkpoint_path = os.path.join(checkpoint_dir, CHECKPOINT_NAME)\n",
    "    if not os.path.exists(checkpoint_path):\n",
    "        return\n",
    "    # Map tensors to the CPU so the checkpoint also loads on CPU-only instances.\n",
    "    data = torch.load(checkpoint_path, map_location=\"cpu\")\n",
    "\n",
    "    net.load_state_dict(data[\"model_state_dict\"])\n",
    "    optimizer.load_state_dict(data[\"optimizer_state_dict\"])\n",
    "    # The stored epoch is the index of the next epoch to run.\n",
    "    start_epoch = data[\"epoch\"]\n",
    "\n",
    "\n",
    "def save_checkpoint(epoch):\n",
    "    \"\"\"Save the model, optimizer state, and training progress as a checkpoint.\"\"\"\n",
    "    checkpoint_dir = os.environ.get(\"PAI_OUTPUT_CHECKPOINTS\")\n",
    "    if not checkpoint_dir:\n",
    "        return\n",
    "    checkpoint_path = os.path.join(checkpoint_dir, CHECKPOINT_NAME)\n",
    "    torch.save({\n",
    "        # Store the next epoch index so that a resumed job starts after this epoch.\n",
    "        'epoch': epoch + 1,\n",
    "        'model_state_dict': net.state_dict(),\n",
    "        'optimizer_state_dict': optimizer.state_dict(),\n",
    "    }, checkpoint_path)\n",
    "\n",
    "\n",
    "def parse_args():\n",
    "    import argparse\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--epochs\", type=int, default=10)\n",
    "    args = parser.parse_args()\n",
    "    return args\n",
    "\n",
    "\n",
    "def train():\n",
    "    args = parse_args()\n",
    "    load_checkpoint()\n",
    "    batch_size = 4\n",
    "    dataloader = DataLoader(RandomDataset(), batch_size=batch_size, shuffle=True)\n",
    "    num_epochs = args.epochs\n",
    "    print(num_epochs)\n",
    "    for epoch in range(start_epoch, num_epochs):\n",
    "        net.train()\n",
    "        for i, (inputs, targets) in enumerate(dataloader):\n",
    "            # Forward pass\n",
    "            outputs = net(inputs)\n",
    "            loss = criterion(outputs, targets)\n",
    "            \n",
    "            # Backward pass and optimization\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            # Print training progress\n",
    "            if (i+1) % 10 == 0:\n",
    "                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item()}')\n",
    "        \n",
    "        # Save checkpoint\n",
    "        save_checkpoint(epoch=epoch)\n",
    "    # save the model\n",
    "    torch.save(net.state_dict(), os.path.join(os.environ.get(\"PAI_OUTPUT_MODEL\", \".\"), \"model.pt\"))\n",
    "    \n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
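    "We submit the code above to PAI. The training job writes its checkpoints to the mounted OSS path (`/ml/output/checkpoints/`) and saves the final model to `/ml/output/model/`."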
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "View the job detail by accessing the console URI: https://pai.console.aliyun.com/?regionId=cn-hangzhou&workspaceId=58670#/training/jobs/train1u1it512gqg\n",
      "TrainingJob launch starting\n",
      "MAX_PARALLELISM=0\n",
      "C_INCLUDE_PATH=/home/pai/include\n",
      "KUBERNETES_PORT=tcp://10.192.0.1:443\n",
      "KUBERNETES_SERVICE_PORT=443\n",
      "LANGUAGE=en_US.UTF-8\n",
      "PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com\n",
      "MASTER_ADDR=train1u1it512gqg-master-0\n",
      "HOSTNAME=train1u1it512gqg-master-0\n",
      "LD_LIBRARY_PATH=:/lib/x86_64-linux-gnu:/home/pai/lib:/home/pai/jre/lib/amd64/server\n",
      "MASTER_PORT=23456\n",
      "HOME=/root\n",
      "PAI_USER_ARGS=\n",
      "PYTHONUNBUFFERED=0\n",
      "PAI_OUTPUT_CHECKPOINTS=/ml/output/checkpoints/\n",
      "PAI_CONFIG_DIR=/ml/input/config/\n",
      "WORLD_SIZE=1\n",
      "REGION_ID=cn-hangzhou\n",
      "CPLUS_INCLUDE_PATH=/home/pai/include\n",
      "RANK=0\n",
      "OPAL_PREFIX=/home/pai/\n",
      "PAI_TRAINING_JOB_ID=train1u1it512gqg\n",
      "TERM=xterm-color\n",
      "KUBERNETES_PORT_443_TCP_ADDR=10.192.0.1\n",
      "PAI_OUTPUT_MODEL=/ml/output/model/\n",
      "ELASTIC_TRAINING_ENABLED=false\n",
      "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/home/pai/bin:/home/pai/hadoop/bin\n",
      "PIP_INDEX_URL=https://mirrors.cloud.aliyuncs.com/pypi/simple\n",
      "KUBERNETES_PORT_443_TCP_PORT=443\n",
      "KUBERNETES_PORT_443_TCP_PROTO=tcp\n",
      "LANG=en_US.UTF-8\n",
      "aliyun_logs_containerType_tags=containerType=Algorithm\n",
      "PAI_TRAINING_USE_ECI=true\n",
      "KUBERNETES_SERVICE_PORT_HTTPS=443\n",
      "KUBERNETES_PORT_443_TCP=tcp://10.192.0.1:443\n",
      "ELASTIC_INFERENCE_ENABLED=false\n",
      "LC_ALL=en_US.UTF-8\n",
      "JAVA_HOME=/home/pai\n",
      "KUBERNETES_SERVICE_HOST=10.192.0.1\n",
      "PWD=/\n",
      "PAI_HPS={}\n",
      "TZ=UTC\n",
      "HADOOP_HOME=/home/pai/hadoop\n",
      "PAI_OUTPUT_LOGS=/ml/output/logs/\n",
      "aliyun_logs_trainingJobId_tags=trainingJobId=train1u1it512gqg\n",
      "PAI_ODPS_CREDENTIAL=/ml/input/credential/odps.json\n",
      "PAI_WORKING_DIR=/ml/usercode/\n",
      "Change to Working Directory, /ml/usercode/\n",
      "User program launching\n",
      "-----------------------------------------------------------------\n",
      "10\n",
      "Epoch [1/10], Step [10/250], Loss: 0.3664854168891907\n",
      "Epoch [1/10], Step [20/250], Loss: 0.5867650508880615\n",
      "Epoch [1/10], Step [30/250], Loss: 0.8810225129127502\n",
      "Epoch [1/10], Step [40/250], Loss: 1.3596220016479492\n",
      "Epoch [1/10], Step [50/250], Loss: 1.0757191181182861\n",
      "Epoch [1/10], Step [60/250], Loss: 0.5261836051940918\n",
      "Epoch [1/10], Step [70/250], Loss: 1.0891999006271362\n",
      "Epoch [1/10], Step [80/250], Loss: 1.2425217628479004\n",
      "Epoch [1/10], Step [90/250], Loss: 0.7928518652915955\n",
      "Epoch [1/10], Step [100/250], Loss: 0.500701367855072\n",
      "Epoch [1/10], Step [110/250], Loss: 1.1105762720108032\n",
      "Epoch [1/10], Step [120/250], Loss: 0.7642831802368164\n",
      "Epoch [1/10], Step [130/250], Loss: 0.9435116052627563\n",
      "Epoch [1/10], Step [140/250], Loss: 0.4632255434989929\n",
      "Epoch [1/10], Step [150/250], Loss: 0.8282555937767029\n",
      "Epoch [1/10], Step [160/250], Loss: 0.5644117593765259\n",
      "Epoch [1/10], Step [170/250], Loss: 0.8821360468864441\n",
      "Epoch [1/10], Step [180/250], Loss: 0.6495410799980164\n",
      "Epoch [1/10], Step [190/250], Loss: 0.6814499497413635\n",
      "Epoch [1/10], Step [200/250], Loss: 1.1818656921386719\n",
      "Epoch [1/10], Step [210/250], Loss: 0.4218548536300659\n",
      "Epoch [1/10], Step [220/250], Loss: 0.5892952680587769\n",
      "Epoch [1/10], Step [230/250], Loss: 0.8104468584060669\n",
      "Epoch [1/10], Step [240/250], Loss: 0.3310832977294922\n",
      "Epoch [1/10], Step [250/250], Loss: 1.0296210050582886\n",
      "Epoch [2/10], Step [10/250], Loss: 0.747037947177887\n",
      "Epoch [2/10], Step [20/250], Loss: 1.0555682182312012\n",
      "Epoch [2/10], Step [30/250], Loss: 0.5005624294281006\n",
      "Epoch [2/10], Step [40/250], Loss: 0.6007864475250244\n",
      "Epoch [2/10], Step [50/250], Loss: 0.8172819018363953\n",
      "Epoch [2/10], Step [60/250], Loss: 0.7322960495948792\n",
      "Epoch [2/10], Step [70/250], Loss: 0.6178841590881348\n",
      "Epoch [2/10], Step [80/250], Loss: 0.9776118993759155\n",
      "Epoch [2/10], Step [90/250], Loss: 0.8088865876197815\n",
      "Epoch [2/10], Step [100/250], Loss: 0.7169486284255981\n",
      "Epoch [2/10], Step [110/250], Loss: 0.8003190159797668\n",
      "Epoch [2/10], Step [120/250], Loss: 0.9178279638290405\n",
      "Epoch [2/10], Step [130/250], Loss: 0.5217956900596619\n",
      "Epoch [2/10], Step [140/250], Loss: 1.2751939296722412\n",
      "Epoch [2/10], Step [150/250], Loss: 1.1024904251098633\n",
      "Epoch [2/10], Step [160/250], Loss: 0.6336060762405396\n",
      "Epoch [2/10], Step [170/250], Loss: 0.799022376537323\n",
      "Epoch [2/10], Step [180/250], Loss: 0.7938567996025085\n",
      "Epoch [2/10], Step [190/250], Loss: 1.060591220855713\n",
      "Epoch [2/10], Step [200/250], Loss: 0.9365970492362976\n",
      "Epoch [2/10], Step [210/250], Loss: 0.6945515871047974\n",
      "Epoch [2/10], Step [220/250], Loss: 0.4772261381149292\n",
      "Epoch [2/10], Step [230/250], Loss: 1.0332412719726562\n",
      "Epoch [2/10], Step [240/250], Loss: 0.7284632325172424\n",
      "Epoch [2/10], Step [250/250], Loss: 0.4485410451889038\n",
      "Epoch [3/10], Step [10/250], Loss: 0.7845520377159119\n",
      "Epoch [3/10], Step [20/250], Loss: 0.5619648694992065\n",
      "Epoch [3/10], Step [30/250], Loss: 0.725273609161377\n",
      "Epoch [3/10], Step [40/250], Loss: 0.7783026695251465\n",
      "Epoch [3/10], Step [50/250], Loss: 0.5168777704238892\n",
      "Epoch [3/10], Step [60/250], Loss: 0.67060387134552\n",
      "Epoch [3/10], Step [70/250], Loss: 0.9300781488418579\n",
      "Epoch [3/10], Step [80/250], Loss: 0.6534505486488342\n",
      "Epoch [3/10], Step [90/250], Loss: 0.557340681552887\n",
      "Epoch [3/10], Step [100/250], Loss: 0.667724609375\n",
      "Epoch [3/10], Step [110/250], Loss: 0.5125826001167297\n",
      "Epoch [3/10], Step [120/250], Loss: 0.4494149088859558\n",
      "Epoch [3/10], Step [130/250], Loss: 0.6902559995651245\n",
      "Epoch [3/10], Step [140/250], Loss: 0.5450549125671387\n",
      "Epoch [3/10], Step [150/250], Loss: 1.0632681846618652\n",
      "Epoch [3/10], Step [160/250], Loss: 0.7964761257171631\n",
      "Epoch [3/10], Step [170/250], Loss: 0.5218536257743835\n",
      "Epoch [3/10], Step [180/250], Loss: 0.6972622275352478\n",
      "Epoch [3/10], Step [190/250], Loss: 0.7963941097259521\n",
      "Epoch [3/10], Step [200/250], Loss: 0.5798731446266174\n",
      "Epoch [3/10], Step [210/250], Loss: 0.7930802702903748\n",
      "Epoch [3/10], Step [220/250], Loss: 0.7618649005889893\n",
      "Epoch [3/10], Step [230/250], Loss: 0.9831617474555969\n",
      "Epoch [3/10], Step [240/250], Loss: 0.7935497164726257\n",
      "Epoch [3/10], Step [250/250], Loss: 0.9747794270515442\n",
      "Epoch [4/10], Step [10/250], Loss: 0.6432996392250061\n",
      "Epoch [4/10], Step [20/250], Loss: 0.6515889167785645\n",
      "Epoch [4/10], Step [30/250], Loss: 0.8191264867782593\n",
      "Epoch [4/10], Step [40/250], Loss: 0.5717310905456543\n",
      "Epoch [4/10], Step [50/250], Loss: 1.0365064144134521\n",
      "Epoch [4/10], Step [60/250], Loss: 0.7181562185287476\n",
      "Epoch [4/10], Step [70/250], Loss: 0.6014276146888733\n",
      "Epoch [4/10], Step [80/250], Loss: 0.8743482232093811\n",
      "Epoch [4/10], Step [90/250], Loss: 0.5963127613067627\n",
      "Epoch [4/10], Step [100/250], Loss: 0.7012943029403687\n",
      "Epoch [4/10], Step [110/250], Loss: 0.6271654367446899\n",
      "Epoch [4/10], Step [120/250], Loss: 0.646144449710846\n",
      "Epoch [4/10], Step [130/250], Loss: 0.5112266540527344\n",
      "Epoch [4/10], Step [140/250], Loss: 0.8657329678535461\n",
      "Epoch [4/10], Step [150/250], Loss: 0.677897572517395\n",
      "Epoch [4/10], Step [160/250], Loss: 0.798669695854187\n",
      "Epoch [4/10], Step [170/250], Loss: 0.805213451385498\n",
      "Epoch [4/10], Step [180/250], Loss: 0.7744658589363098\n",
      "Epoch [4/10], Step [190/250], Loss: 0.4748728275299072\n",
      "Epoch [4/10], Step [200/250], Loss: 0.6623726487159729\n",
      "Epoch [4/10], Step [210/250], Loss: 0.6851851940155029\n",
      "Epoch [4/10], Step [220/250], Loss: 0.5917701721191406\n",
      "Epoch [4/10], Step [230/250], Loss: 0.586968719959259\n",
      "Epoch [4/10], Step [240/250], Loss: 0.758073091506958\n",
      "Epoch [4/10], Step [250/250], Loss: 0.7908360958099365\n",
      "Epoch [5/10], Step [10/250], Loss: 0.747495174407959\n",
      "Epoch [5/10], Step [20/250], Loss: 0.7880417108535767\n",
      "Epoch [5/10], Step [30/250], Loss: 1.4239259958267212\n",
      "Epoch [5/10], Step [40/250], Loss: 0.709957480430603\n",
      "Epoch [5/10], Step [50/250], Loss: 0.45279955863952637\n",
      "Epoch [5/10], Step [60/250], Loss: 0.6855078935623169\n",
      "Epoch [5/10], Step [70/250], Loss: 0.7050631046295166\n",
      "Epoch [5/10], Step [80/250], Loss: 0.8256967067718506\n",
      "Epoch [5/10], Step [90/250], Loss: 0.9627029895782471\n",
      "Epoch [5/10], Step [100/250], Loss: 0.7069070339202881\n",
      "Epoch [5/10], Step [110/250], Loss: 0.6772119998931885\n",
      "Epoch [5/10], Step [120/250], Loss: 0.5547316670417786\n",
      "Epoch [5/10], Step [130/250], Loss: 0.4749568998813629\n",
      "Epoch [5/10], Step [140/250], Loss: 0.5910231471061707\n",
      "Epoch [5/10], Step [150/250], Loss: 0.5789163112640381\n",
      "Epoch [5/10], Step [160/250], Loss: 0.994613766670227\n",
      "Epoch [5/10], Step [170/250], Loss: 0.7664419412612915\n",
      "Epoch [5/10], Step [180/250], Loss: 0.7812412977218628\n",
      "Epoch [5/10], Step [190/250], Loss: 0.932634174823761\n",
      "Epoch [5/10], Step [200/250], Loss: 0.4732060134410858\n",
      "Epoch [5/10], Step [210/250], Loss: 0.6712639927864075\n",
      "Epoch [5/10], Step [220/250], Loss: 0.7019771337509155\n",
      "Epoch [5/10], Step [230/250], Loss: 0.668921709060669\n",
      "Epoch [5/10], Step [240/250], Loss: 0.5486156344413757\n",
      "Epoch [5/10], Step [250/250], Loss: 0.8131189346313477\n",
      "Epoch [6/10], Step [10/250], Loss: 0.5800281167030334\n",
      "Epoch [6/10], Step [20/250], Loss: 0.9032570719718933\n",
      "Epoch [6/10], Step [30/250], Loss: 0.6829659938812256\n",
      "Epoch [6/10], Step [40/250], Loss: 0.577970027923584\n",
      "Epoch [6/10], Step [50/250], Loss: 0.9745671153068542\n",
      "Epoch [6/10], Step [60/250], Loss: 0.6292040348052979\n",
      "Epoch [6/10], Step [70/250], Loss: 0.9189562201499939\n",
      "Epoch [6/10], Step [80/250], Loss: 1.0687212944030762\n",
      "Epoch [6/10], Step [90/250], Loss: 0.6210573315620422\n",
      "Epoch [6/10], Step [100/250], Loss: 0.7758654356002808\n",
      "Epoch [6/10], Step [110/250], Loss: 1.055539846420288\n",
      "Epoch [6/10], Step [120/250], Loss: 0.7991855144500732\n",
      "Epoch [6/10], Step [130/250], Loss: 0.8390480279922485\n",
      "Epoch [6/10], Step [140/250], Loss: 0.5641282200813293\n",
      "Epoch [6/10], Step [150/250], Loss: 0.5416208505630493\n",
      "Epoch [6/10], Step [160/250], Loss: 0.8556939363479614\n",
      "Epoch [6/10], Step [170/250], Loss: 0.8848042488098145\n",
      "Epoch [6/10], Step [180/250], Loss: 0.6585526466369629\n",
      "Epoch [6/10], Step [190/250], Loss: 0.5264347791671753\n",
      "Epoch [6/10], Step [200/250], Loss: 0.7451325058937073\n",
      "Epoch [6/10], Step [210/250], Loss: 0.8498039841651917\n",
      "Epoch [6/10], Step [220/250], Loss: 0.9514821767807007\n",
      "Epoch [6/10], Step [230/250], Loss: 0.5831080675125122\n",
      "Epoch [6/10], Step [240/250], Loss: 0.7323013544082642\n",
      "Epoch [6/10], Step [250/250], Loss: 0.799047589302063\n",
      "Epoch [7/10], Step [10/250], Loss: 0.7431624531745911\n",
      "Epoch [7/10], Step [20/250], Loss: 0.7462856769561768\n",
      "Epoch [7/10], Step [30/250], Loss: 0.7637103796005249\n",
      "Epoch [7/10], Step [40/250], Loss: 0.7512863874435425\n",
      "Epoch [7/10], Step [50/250], Loss: 0.8934370279312134\n",
      "Epoch [7/10], Step [60/250], Loss: 0.6657339334487915\n",
      "Epoch [7/10], Step [70/250], Loss: 0.7996265292167664\n",
      "Epoch [7/10], Step [80/250], Loss: 0.7883811593055725\n",
      "Epoch [7/10], Step [90/250], Loss: 0.7327611446380615\n",
      "Epoch [7/10], Step [100/250], Loss: 0.7103905081748962\n",
      "Epoch [7/10], Step [110/250], Loss: 0.8145009875297546\n",
      "Epoch [7/10], Step [120/250], Loss: 0.6999544501304626\n",
      "Epoch [7/10], Step [130/250], Loss: 0.6132965087890625\n",
      "Epoch [7/10], Step [140/250], Loss: 0.8219666481018066\n",
      "Epoch [7/10], Step [150/250], Loss: 0.573877215385437\n",
      "Epoch [7/10], Step [160/250], Loss: 0.864593505859375\n",
      "Epoch [7/10], Step [170/250], Loss: 0.7187140583992004\n",
      "Epoch [7/10], Step [180/250], Loss: 0.601334810256958\n",
      "Epoch [7/10], Step [190/250], Loss: 0.6193158626556396\n",
      "Epoch [7/10], Step [200/250], Loss: 0.7600311040878296\n",
      "Epoch [7/10], Step [210/250], Loss: 0.6659085154533386\n",
      "Epoch [7/10], Step [220/250], Loss: 0.6364413499832153\n",
      "Epoch [7/10], Step [230/250], Loss: 0.878304123878479\n",
      "Epoch [7/10], Step [240/250], Loss: 0.7139410972595215\n",
      "Epoch [7/10], Step [250/250], Loss: 0.6852972507476807\n",
      "Epoch [8/10], Step [10/250], Loss: 1.0263853073120117\n",
      "Epoch [8/10], Step [20/250], Loss: 0.7559791803359985\n",
      "Epoch [8/10], Step [30/250], Loss: 0.6709325313568115\n",
      "Epoch [8/10], Step [40/250], Loss: 0.5146634578704834\n",
      "Epoch [8/10], Step [50/250], Loss: 0.6418485641479492\n",
      "Epoch [8/10], Step [60/250], Loss: 0.72318035364151\n",
      "Epoch [8/10], Step [70/250], Loss: 0.7116968631744385\n",
      "Epoch [8/10], Step [80/250], Loss: 0.7035868763923645\n",
      "Epoch [8/10], Step [90/250], Loss: 0.6085933446884155\n",
      "Epoch [8/10], Step [100/250], Loss: 0.5128545761108398\n",
      "Epoch [8/10], Step [110/250], Loss: 0.6380510330200195\n",
      "Epoch [8/10], Step [120/250], Loss: 0.4963105320930481\n",
      "Epoch [8/10], Step [130/250], Loss: 0.6693160533905029\n",
      "Epoch [8/10], Step [140/250], Loss: 0.6602588891983032\n",
      "Epoch [8/10], Step [150/250], Loss: 0.8440876007080078\n",
      "Epoch [8/10], Step [160/250], Loss: 0.7596740126609802\n",
      "Epoch [8/10], Step [170/250], Loss: 0.695992112159729\n",
      "Epoch [8/10], Step [180/250], Loss: 0.6737014651298523\n",
      "Epoch [8/10], Step [190/250], Loss: 0.6722623705863953\n",
      "Epoch [8/10], Step [200/250], Loss: 0.5857406854629517\n",
      "Epoch [8/10], Step [210/250], Loss: 0.9563039541244507\n",
      "Epoch [8/10], Step [220/250], Loss: 0.7375826835632324\n",
      "Epoch [8/10], Step [230/250], Loss: 0.8751094341278076\n",
      "Epoch [8/10], Step [240/250], Loss: 0.7180076837539673\n",
      "Epoch [8/10], Step [250/250], Loss: 0.6384711861610413\n",
      "Epoch [9/10], Step [10/250], Loss: 0.6789698004722595\n",
      "Epoch [9/10], Step [20/250], Loss: 0.6645065546035767\n",
      "Epoch [9/10], Step [30/250], Loss: 0.6996726989746094\n",
      "Epoch [9/10], Step [40/250], Loss: 0.7402397394180298\n",
      "Epoch [9/10], Step [50/250], Loss: 0.6388964653015137\n",
      "Epoch [9/10], Step [60/250], Loss: 0.9401450753211975\n",
      "Epoch [9/10], Step [70/250], Loss: 0.6708970665931702\n",
      "Epoch [9/10], Step [80/250], Loss: 0.728550136089325\n",
      "Epoch [9/10], Step [90/250], Loss: 0.7362596988677979\n",
      "Epoch [9/10], Step [100/250], Loss: 0.7750495672225952\n",
      "Epoch [9/10], Step [110/250], Loss: 0.807244062423706\n",
      "Epoch [9/10], Step [120/250], Loss: 0.754521369934082\n",
      "Epoch [9/10], Step [130/250], Loss: 0.5469345450401306\n",
      "Epoch [9/10], Step [140/250], Loss: 0.8965460062026978\n",
      "Epoch [9/10], Step [150/250], Loss: 0.7952369451522827\n",
      "Epoch [9/10], Step [160/250], Loss: 0.6263578534126282\n",
      "Epoch [9/10], Step [170/250], Loss: 0.5788871049880981\n",
      "Epoch [9/10], Step [180/250], Loss: 0.7363749146461487\n",
      "Epoch [9/10], Step [190/250], Loss: 0.7322844862937927\n",
      "Epoch [9/10], Step [200/250], Loss: 0.6707043051719666\n",
      "Epoch [9/10], Step [210/250], Loss: 0.7251213192939758\n",
      "Epoch [9/10], Step [220/250], Loss: 0.6435517072677612\n",
      "Epoch [9/10], Step [230/250], Loss: 0.534774124622345\n",
      "Epoch [9/10], Step [240/250], Loss: 0.6989405751228333\n",
      "Epoch [9/10], Step [250/250], Loss: 0.7413943409919739\n",
      "Epoch [10/10], Step [10/250], Loss: 0.6014090776443481\n",
      "Epoch [10/10], Step [20/250], Loss: 0.8173813819885254\n",
      "Epoch [10/10], Step [30/250], Loss: 0.8984671235084534\n",
      "Epoch [10/10], Step [40/250], Loss: 0.6354056000709534\n",
      "Epoch [10/10], Step [50/250], Loss: 0.7964866757392883\n",
      "Epoch [10/10], Step [60/250], Loss: 0.7849454879760742\n",
      "Epoch [10/10], Step [70/250], Loss: 0.5637381076812744\n",
      "Epoch [10/10], Step [80/250], Loss: 0.7669687271118164\n",
      "Epoch [10/10], Step [90/250], Loss: 0.6140038371086121\n",
      "Epoch [10/10], Step [100/250], Loss: 0.7134058475494385\n",
      "Epoch [10/10], Step [110/250], Loss: 0.6768066883087158\n",
      "Epoch [10/10], Step [120/250], Loss: 0.6304113268852234\n",
      "Epoch [10/10], Step [130/250], Loss: 0.7426990866661072\n",
      "Epoch [10/10], Step [140/250], Loss: 0.7469097971916199\n",
      "Epoch [10/10], Step [150/250], Loss: 0.7591947913169861\n",
      "Epoch [10/10], Step [160/250], Loss: 0.7327935099601746\n",
      "Epoch [10/10], Step [170/250], Loss: 0.8590223789215088\n",
      "Epoch [10/10], Step [180/250], Loss: 0.6994909644126892\n",
      "Epoch [10/10], Step [190/250], Loss: 0.8262240886688232\n",
      "Epoch [10/10], Step [200/250], Loss: 0.6071692109107971\n",
      "Epoch [10/10], Step [210/250], Loss: 0.915013313293457\n",
      "Epoch [10/10], Step [220/250], Loss: 0.8758894205093384\n",
      "Epoch [10/10], Step [230/250], Loss: 0.6473208665847778\n",
      "Epoch [10/10], Step [240/250], Loss: 0.6843898296356201\n",
      "Epoch [10/10], Step [250/250], Loss: 0.6645953059196472\n",
      "\n",
      "Training job (train1u1it512gqg) succeeded, you can check the logs/metrics/output in  the console:\n",
      "https://pai.console.aliyun.com/?regionId=cn-hangzhou&workspaceId=58670#/training/jobs/train1u1it512gqg\n"
     ]
    }
   ],
   "source": [
    "from pai.estimator import Estimator\n",
    "from pai.image import retrieve\n",
    "\n",
    "\n",
    "epochs = 10\n",
    "\n",
    "\n",
    "# By default, the training job mounts an OSS Bucket path to /ml/output/checkpoints/\n",
    "est = Estimator(\n",
    "    command=\"python train.py --epochs {}\".format(epochs),\n",
    "    source_dir=\"./train_src/\",\n",
    "    image_uri=retrieve(\"PyTorch\", \"latest\").image_uri,\n",
    "    instance_type=\"ecs.c6.large\",\n",
    "    base_job_name=\"torch_checkpoint\",\n",
    ")\n",
    "\n",
    "est.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OSS path of the training job's checkpoints directory\n",
    "print(est.checkpoints_data())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
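    "The job above iterated over the training data for 10 epochs. With the checkpoint, we can continue training from the saved model, for example for another 20 epochs. Because `--epochs` sets the total number of epochs and training resumes from the epoch stored in the checkpoint, the new job runs with `--epochs 30`, as its log below shows.\n"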
   ]
  },
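  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The epoch arithmetic can be checked with a small sketch that mirrors the loop in `train.py`: the script iterates over `range(start_epoch, num_epochs)`, and the checkpoint stores `epoch + 1`. `epochs_to_run` is a hypothetical helper used only for illustration.\n",
    "\n",
    "```python\n",
    "def epochs_to_run(start_epoch, num_epochs):\n",
    "    # Same bounds as the training loop in train.py (0-based epoch indices).\n",
    "    return list(range(start_epoch, num_epochs))\n",
    "\n",
    "\n",
    "# First job: no checkpoint yet, --epochs 10.\n",
    "first_job = epochs_to_run(0, 10)\n",
    "# The last save stores epoch 9 + 1 = 10, which becomes start_epoch of the next job.\n",
    "resumed_job = epochs_to_run(first_job[-1] + 1, 30)\n",
    "\n",
    "print(len(resumed_job))  # 20 additional epochs\n",
    "print(f'Epoch [{resumed_job[0] + 1}/30]')  # Epoch [11/30], the first epoch the resumed job logs\n",
    "```"
   ]
  },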
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "View the job detail by accessing the console URI: https://pai.console.aliyun.com/?regionId=cn-hangzhou&workspaceId=58670#/training/jobs/trainu90lc57j1vm\n",
      "TrainingJob launch starting\n",
      "MAX_PARALLELISM=0\n",
      "C_INCLUDE_PATH=/home/pai/include\n",
      "KUBERNETES_SERVICE_PORT=443\n",
      "KUBERNETES_PORT=tcp://10.192.0.1:443\n",
      "LANGUAGE=en_US.UTF-8\n",
      "PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com\n",
      "MASTER_ADDR=trainu90lc57j1vm-master-0\n",
      "HOSTNAME=trainu90lc57j1vm-master-0\n",
      "LD_LIBRARY_PATH=:/lib/x86_64-linux-gnu:/home/pai/lib:/home/pai/jre/lib/amd64/server\n",
      "MASTER_PORT=23456\n",
      "HOME=/root\n",
      "PAI_USER_ARGS=\n",
      "PYTHONUNBUFFERED=0\n",
      "PAI_OUTPUT_CHECKPOINTS=/ml/output/checkpoints/\n",
      "PAI_CONFIG_DIR=/ml/input/config/\n",
      "WORLD_SIZE=1\n",
      "REGION_ID=cn-hangzhou\n",
      "CPLUS_INCLUDE_PATH=/home/pai/include\n",
      "RANK=0\n",
      "OPAL_PREFIX=/home/pai/\n",
      "PAI_TRAINING_JOB_ID=trainu90lc57j1vm\n",
      "TERM=xterm-color\n",
      "KUBERNETES_PORT_443_TCP_ADDR=10.192.0.1\n",
      "PAI_OUTPUT_MODEL=/ml/output/model/\n",
      "ELASTIC_TRAINING_ENABLED=false\n",
      "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/home/pai/bin:/home/pai/hadoop/bin\n",
      "PIP_INDEX_URL=https://mirrors.cloud.aliyuncs.com/pypi/simple\n",
      "KUBERNETES_PORT_443_TCP_PORT=443\n",
      "KUBERNETES_PORT_443_TCP_PROTO=tcp\n",
      "LANG=en_US.UTF-8\n",
      "PAI_TRAINING_USE_ECI=true\n",
      "aliyun_logs_containerType_tags=containerType=Algorithm\n",
      "KUBERNETES_PORT_443_TCP=tcp://10.192.0.1:443\n",
      "KUBERNETES_SERVICE_PORT_HTTPS=443\n",
      "ELASTIC_INFERENCE_ENABLED=false\n",
      "LC_ALL=en_US.UTF-8\n",
      "JAVA_HOME=/home/pai\n",
      "KUBERNETES_SERVICE_HOST=10.192.0.1\n",
      "PWD=/\n",
      "PAI_HPS={}\n",
      "TZ=UTC\n",
      "HADOOP_HOME=/home/pai/hadoop\n",
      "PAI_OUTPUT_LOGS=/ml/output/logs/\n",
      "aliyun_logs_trainingJobId_tags=trainingJobId=trainu90lc57j1vm\n",
      "PAI_ODPS_CREDENTIAL=/ml/input/credential/odps.json\n",
      "PAI_WORKING_DIR=/ml/usercode/\n",
      "Change to Working Directory, /ml/usercode/\n",
      "User program launching\n",
      "-----------------------------------------------------------------\n",
      "30\n",
      "Epoch [11/30], Step [10/250], Loss: 0.678845226764679\n",
      "Epoch [11/30], Step [20/250], Loss: 0.6292213201522827\n",
      "Epoch [11/30], Step [30/250], Loss: 0.6856911182403564\n",
      "Epoch [11/30], Step [40/250], Loss: 0.6147192716598511\n",
      "Epoch [11/30], Step [50/250], Loss: 0.7846511602401733\n",
      "Epoch [11/30], Step [60/250], Loss: 0.6719473004341125\n",
      "Epoch [11/30], Step [70/250], Loss: 0.8227031826972961\n",
      "Epoch [11/30], Step [80/250], Loss: 0.7861220836639404\n",
      "Epoch [11/30], Step [90/250], Loss: 0.7436649203300476\n",
      "Epoch [11/30], Step [100/250], Loss: 0.8053247928619385\n",
      "Epoch [11/30], Step [110/250], Loss: 0.716484546661377\n",
      "Epoch [11/30], Step [120/250], Loss: 0.6527263522148132\n",
      "Epoch [11/30], Step [130/250], Loss: 0.7980918884277344\n",
      "Epoch [11/30], Step [140/250], Loss: 0.6761615872383118\n",
      "Epoch [11/30], Step [150/250], Loss: 0.8030520081520081\n",
      "Epoch [11/30], Step [160/250], Loss: 0.6580255627632141\n",
      "Epoch [11/30], Step [170/250], Loss: 0.7671869993209839\n",
      "Epoch [11/30], Step [180/250], Loss: 0.6622000932693481\n",
      "Epoch [11/30], Step [190/250], Loss: 0.747247576713562\n",
      "Epoch [11/30], Step [200/250], Loss: 0.705307126045227\n",
      "Epoch [11/30], Step [210/250], Loss: 0.6516950130462646\n",
      "Epoch [11/30], Step [220/250], Loss: 0.6065223217010498\n",
      "Epoch [11/30], Step [230/250], Loss: 0.6885045766830444\n",
      "Epoch [11/30], Step [240/250], Loss: 0.7392936944961548\n",
      "Epoch [11/30], Step [250/250], Loss: 0.6803852319717407\n",
      "Epoch [12/30], Step [10/250], Loss: 0.8813486695289612\n",
      "Epoch [12/30], Step [20/250], Loss: 0.7780698537826538\n",
      "Epoch [12/30], Step [30/250], Loss: 0.7158650159835815\n",
      "Epoch [12/30], Step [40/250], Loss: 0.5826153755187988\n",
      "Epoch [12/30], Step [50/250], Loss: 0.6013429760932922\n",
      "Epoch [12/30], Step [60/250], Loss: 0.7084614634513855\n",
      "Epoch [12/30], Step [70/250], Loss: 0.6825753450393677\n",
      "Epoch [12/30], Step [80/250], Loss: 0.6074261665344238\n",
      "Epoch [12/30], Step [90/250], Loss: 0.8619674444198608\n",
      "Epoch [12/30], Step [100/250], Loss: 0.6013283729553223\n",
      "Epoch [12/30], Step [110/250], Loss: 0.6808617115020752\n",
      "Epoch [12/30], Step [120/250], Loss: 0.6765388250350952\n",
      "Epoch [12/30], Step [130/250], Loss: 0.7072106599807739\n",
      "Epoch [12/30], Step [140/250], Loss: 0.6905199289321899\n",
      "Epoch [12/30], Step [150/250], Loss: 0.6942532062530518\n",
      "Epoch [12/30], Step [160/250], Loss: 0.7181805968284607\n",
      "Epoch [12/30], Step [170/250], Loss: 0.6357207298278809\n",
      "Epoch [12/30], Step [180/250], Loss: 0.6719130277633667\n",
      "Epoch [12/30], Step [190/250], Loss: 0.7218160629272461\n",
      "Epoch [12/30], Step [200/250], Loss: 0.7158771753311157\n",
      "Epoch [12/30], Step [210/250], Loss: 0.7585588693618774\n",
      "Epoch [12/30], Step [220/250], Loss: 0.8121419548988342\n",
      "Epoch [12/30], Step [230/250], Loss: 0.7744668126106262\n",
      "Epoch [12/30], Step [240/250], Loss: 0.7164073586463928\n",
      "Epoch [12/30], Step [250/250], Loss: 0.5488151907920837\n",
      "Epoch [13/30], Step [10/250], Loss: 0.7662173509597778\n",
      "Epoch [13/30], Step [20/250], Loss: 0.7802825570106506\n",
      "Epoch [13/30], Step [30/250], Loss: 0.7456352114677429\n",
      "Epoch [13/30], Step [40/250], Loss: 0.6143842935562134\n",
      "Epoch [13/30], Step [50/250], Loss: 0.7393404245376587\n",
      "Epoch [13/30], Step [60/250], Loss: 0.6536136865615845\n",
      "Epoch [13/30], Step [70/250], Loss: 0.7647539377212524\n",
      "Epoch [13/30], Step [80/250], Loss: 0.6415259838104248\n",
      "Epoch [13/30], Step [90/250], Loss: 0.8065975904464722\n",
      "Epoch [13/30], Step [100/250], Loss: 0.654565155506134\n",
      "Epoch [13/30], Step [110/250], Loss: 0.6512014865875244\n",
      "Epoch [13/30], Step [120/250], Loss: 0.6851429343223572\n",
      "Epoch [13/30], Step [130/250], Loss: 0.7639355659484863\n",
      "Epoch [13/30], Step [140/250], Loss: 0.7886079549789429\n",
      "Epoch [13/30], Step [150/250], Loss: 0.677024245262146\n",
      "Epoch [13/30], Step [160/250], Loss: 0.6869807243347168\n",
      "Epoch [13/30], Step [170/250], Loss: 0.7076682448387146\n",
      "Epoch [13/30], Step [180/250], Loss: 0.6720783710479736\n",
      "Epoch [13/30], Step [190/250], Loss: 0.6578226685523987\n",
      "Epoch [13/30], Step [200/250], Loss: 0.6924010515213013\n",
      "Epoch [13/30], Step [210/250], Loss: 0.8084946870803833\n",
      "Epoch [13/30], Step [220/250], Loss: 0.7015032768249512\n",
      "Epoch [13/30], Step [230/250], Loss: 0.6897311210632324\n",
      "Epoch [13/30], Step [240/250], Loss: 0.7233715653419495\n",
      "Epoch [13/30], Step [250/250], Loss: 0.82469242811203\n",
      "Epoch [14/30], Step [10/250], Loss: 0.7118442058563232\n",
      "Epoch [14/30], Step [20/250], Loss: 0.66881263256073\n",
      "Epoch [14/30], Step [30/250], Loss: 0.6966590881347656\n",
      "Epoch [14/30], Step [40/250], Loss: 0.8390185236930847\n",
      "Epoch [14/30], Step [50/250], Loss: 0.7978378534317017\n",
      "Epoch [14/30], Step [60/250], Loss: 0.6207278966903687\n",
      "Epoch [14/30], Step [70/250], Loss: 0.6512827277183533\n",
      "Epoch [14/30], Step [80/250], Loss: 0.6850301027297974\n",
      "Epoch [14/30], Step [90/250], Loss: 0.628646194934845\n",
      "Epoch [14/30], Step [100/250], Loss: 0.6093996167182922\n",
      "Epoch [14/30], Step [110/250], Loss: 0.7588788866996765\n",
      "Epoch [14/30], Step [120/250], Loss: 0.6795099377632141\n",
      "Epoch [14/30], Step [130/250], Loss: 0.6357916593551636\n",
      "Epoch [14/30], Step [140/250], Loss: 0.7358158826828003\n",
      "Epoch [14/30], Step [150/250], Loss: 0.6896149516105652\n",
      "Epoch [14/30], Step [160/250], Loss: 0.6862155199050903\n",
      "Epoch [14/30], Step [170/250], Loss: 0.659408688545227\n",
      "Epoch [14/30], Step [180/250], Loss: 0.717597246170044\n",
      "Epoch [14/30], Step [190/250], Loss: 0.6779205203056335\n",
      "Epoch [14/30], Step [200/250], Loss: 0.6569654941558838\n",
      "Epoch [14/30], Step [210/250], Loss: 0.6521044373512268\n",
      "Epoch [14/30], Step [220/250], Loss: 0.5803452134132385\n",
      "Epoch [14/30], Step [230/250], Loss: 0.6112836599349976\n",
      "Epoch [14/30], Step [240/250], Loss: 0.6311125755310059\n",
      "Epoch [14/30], Step [250/250], Loss: 0.6427040696144104\n",
      "Epoch [15/30], Step [10/250], Loss: 0.7193827629089355\n",
      "Epoch [15/30], Step [20/250], Loss: 0.6781796216964722\n",
      "Epoch [15/30], Step [30/250], Loss: 0.7042354345321655\n",
      "Epoch [15/30], Step [40/250], Loss: 0.6776638627052307\n",
      "Epoch [15/30], Step [50/250], Loss: 0.6593765020370483\n",
      "Epoch [15/30], Step [60/250], Loss: 0.6749820113182068\n",
      "Epoch [15/30], Step [70/250], Loss: 0.6199281811714172\n",
      "Epoch [15/30], Step [80/250], Loss: 0.6898410320281982\n",
      "Epoch [15/30], Step [90/250], Loss: 0.6938673257827759\n",
      "Epoch [15/30], Step [100/250], Loss: 0.6369883418083191\n",
      "Epoch [15/30], Step [110/250], Loss: 0.6758348345756531\n",
      "Epoch [15/30], Step [120/250], Loss: 0.7379288673400879\n",
      "Epoch [15/30], Step [130/250], Loss: 0.6447997689247131\n",
      "Epoch [15/30], Step [140/250], Loss: 0.6910532712936401\n",
      "Epoch [15/30], Step [150/250], Loss: 0.7426170110702515\n",
      "Epoch [15/30], Step [160/250], Loss: 0.6422319412231445\n",
      "Epoch [15/30], Step [170/250], Loss: 0.5789802670478821\n",
      "Epoch [15/30], Step [180/250], Loss: 0.7434327602386475\n",
      "Epoch [15/30], Step [190/250], Loss: 0.6754781007766724\n",
      "Epoch [15/30], Step [200/250], Loss: 0.5865523815155029\n",
      "Epoch [15/30], Step [210/250], Loss: 0.6548283696174622\n",
      "Epoch [15/30], Step [220/250], Loss: 0.7495550513267517\n",
      "Epoch [15/30], Step [230/250], Loss: 0.6538060903549194\n",
      "Epoch [15/30], Step [240/250], Loss: 0.7314434051513672\n",
      "Epoch [15/30], Step [250/250], Loss: 0.7135218381881714\n",
      "Epoch [16/30], Step [10/250], Loss: 0.7383496761322021\n",
      "Epoch [16/30], Step [20/250], Loss: 0.644036591053009\n",
      "Epoch [16/30], Step [30/250], Loss: 0.6101108193397522\n",
      "Epoch [16/30], Step [40/250], Loss: 0.7390760779380798\n",
      "Epoch [16/30], Step [50/250], Loss: 0.6870918273925781\n",
      "Epoch [16/30], Step [60/250], Loss: 0.6894906759262085\n",
      "Epoch [16/30], Step [70/250], Loss: 0.7674188017845154\n",
      "Epoch [16/30], Step [80/250], Loss: 0.7476275563240051\n",
      "Epoch [16/30], Step [90/250], Loss: 0.7009009718894958\n",
      "Epoch [16/30], Step [100/250], Loss: 0.6951045989990234\n",
      "Epoch [16/30], Step [110/250], Loss: 0.7023512721061707\n",
      "Epoch [16/30], Step [120/250], Loss: 0.6900476217269897\n",
      "Epoch [16/30], Step [130/250], Loss: 0.7070642709732056\n",
      "Epoch [16/30], Step [140/250], Loss: 0.6627304553985596\n",
      "Epoch [16/30], Step [150/250], Loss: 0.676548182964325\n",
      "Epoch [16/30], Step [160/250], Loss: 0.7038763761520386\n",
      "Epoch [16/30], Step [170/250], Loss: 0.6916297078132629\n",
      "Epoch [16/30], Step [180/250], Loss: 0.7028259634971619\n",
      "Epoch [16/30], Step [190/250], Loss: 0.6524210572242737\n",
      "Epoch [16/30], Step [200/250], Loss: 0.7346513867378235\n",
      "Epoch [16/30], Step [210/250], Loss: 0.612514317035675\n",
      "Epoch [16/30], Step [220/250], Loss: 0.7455917596817017\n",
      "Epoch [16/30], Step [230/250], Loss: 0.747292160987854\n",
      "Epoch [16/30], Step [240/250], Loss: 0.7447240352630615\n",
      "Epoch [16/30], Step [250/250], Loss: 0.6769564747810364\n",
      "Epoch [17/30], Step [10/250], Loss: 0.7425077557563782\n",
      "Epoch [17/30], Step [20/250], Loss: 0.6944329738616943\n",
      "Epoch [17/30], Step [30/250], Loss: 0.6961978673934937\n",
      "Epoch [17/30], Step [40/250], Loss: 0.6465986967086792\n",
      "Epoch [17/30], Step [50/250], Loss: 0.714703381061554\n",
      "Epoch [17/30], Step [60/250], Loss: 0.5930614471435547\n",
      "Epoch [17/30], Step [70/250], Loss: 0.6468428373336792\n",
      "Epoch [17/30], Step [80/250], Loss: 0.686537504196167\n",
      "Epoch [17/30], Step [90/250], Loss: 0.7371711730957031\n",
      "Epoch [17/30], Step [100/250], Loss: 0.7700399160385132\n",
      "Epoch [17/30], Step [110/250], Loss: 0.7529278993606567\n",
      "Epoch [17/30], Step [120/250], Loss: 0.7036042213439941\n",
      "Epoch [17/30], Step [130/250], Loss: 0.7871543765068054\n",
      "Epoch [17/30], Step [140/250], Loss: 0.6956086158752441\n",
      "Epoch [17/30], Step [150/250], Loss: 0.7426921725273132\n",
      "Epoch [17/30], Step [160/250], Loss: 0.7222756743431091\n",
      "Epoch [17/30], Step [170/250], Loss: 0.6826121807098389\n",
      "Epoch [17/30], Step [180/250], Loss: 0.6970388293266296\n",
      "Epoch [17/30], Step [190/250], Loss: 0.7087472677230835\n",
      "Epoch [17/30], Step [200/250], Loss: 0.6320711374282837\n",
      "Epoch [17/30], Step [210/250], Loss: 0.7280303835868835\n",
      "Epoch [17/30], Step [220/250], Loss: 0.6934517621994019\n",
      "Epoch [17/30], Step [230/250], Loss: 0.7071420550346375\n",
      "Epoch [17/30], Step [240/250], Loss: 0.6856362223625183\n",
      "Epoch [17/30], Step [250/250], Loss: 0.6945990324020386\n",
      "Epoch [18/30], Step [10/250], Loss: 0.6465855240821838\n",
      "Epoch [18/30], Step [20/250], Loss: 0.7086865901947021\n",
      "Epoch [18/30], Step [30/250], Loss: 0.6256162524223328\n",
      "Epoch [18/30], Step [40/250], Loss: 0.6532611846923828\n",
      "Epoch [18/30], Step [50/250], Loss: 0.6484596729278564\n",
      "Epoch [18/30], Step [60/250], Loss: 0.6955176591873169\n",
      "Epoch [18/30], Step [70/250], Loss: 0.6615030765533447\n",
      "Epoch [18/30], Step [80/250], Loss: 0.7038217186927795\n",
      "Epoch [18/30], Step [90/250], Loss: 0.6943345069885254\n",
      "Epoch [18/30], Step [100/250], Loss: 0.7004052996635437\n",
      "Epoch [18/30], Step [110/250], Loss: 0.7458634972572327\n",
      "Epoch [18/30], Step [120/250], Loss: 0.6851629614830017\n",
      "Epoch [18/30], Step [130/250], Loss: 0.682853102684021\n",
      "Epoch [18/30], Step [140/250], Loss: 0.6481672525405884\n",
      "Epoch [18/30], Step [150/250], Loss: 0.7038549780845642\n",
      "Epoch [18/30], Step [160/250], Loss: 0.6995554566383362\n",
      "Epoch [18/30], Step [170/250], Loss: 0.6800370216369629\n",
      "Epoch [18/30], Step [180/250], Loss: 0.6488386392593384\n",
      "Epoch [18/30], Step [190/250], Loss: 0.7000787854194641\n",
      "Epoch [18/30], Step [200/250], Loss: 0.7428950071334839\n",
      "Epoch [18/30], Step [210/250], Loss: 0.6872988343238831\n",
      "Epoch [18/30], Step [220/250], Loss: 0.6482336521148682\n",
      "Epoch [18/30], Step [230/250], Loss: 0.6626957058906555\n",
      "Epoch [18/30], Step [240/250], Loss: 0.6778802275657654\n",
      "Epoch [18/30], Step [250/250], Loss: 0.7027387022972107\n",
      "Epoch [19/30], Step [10/250], Loss: 0.6812503933906555\n",
      "Epoch [19/30], Step [20/250], Loss: 0.6751934289932251\n",
      "Epoch [19/30], Step [30/250], Loss: 0.6624279618263245\n",
      "Epoch [19/30], Step [40/250], Loss: 0.6787773966789246\n",
      "Epoch [19/30], Step [50/250], Loss: 0.7765601873397827\n",
      "Epoch [19/30], Step [60/250], Loss: 0.6592363119125366\n",
      "Epoch [19/30], Step [70/250], Loss: 0.7038179039955139\n",
      "Epoch [19/30], Step [80/250], Loss: 0.7358537316322327\n",
      "Epoch [19/30], Step [90/250], Loss: 0.708828330039978\n",
      "Epoch [19/30], Step [100/250], Loss: 0.7642552852630615\n",
      "Epoch [19/30], Step [110/250], Loss: 0.7605912089347839\n",
      "Epoch [19/30], Step [120/250], Loss: 0.6976773738861084\n",
      "Epoch [19/30], Step [130/250], Loss: 0.6766220331192017\n",
      "Epoch [19/30], Step [140/250], Loss: 0.7171740531921387\n",
      "Epoch [19/30], Step [150/250], Loss: 0.6521143913269043\n",
      "Epoch [19/30], Step [160/250], Loss: 0.6554864645004272\n",
      "Epoch [19/30], Step [170/250], Loss: 0.6797289848327637\n",
      "Epoch [19/30], Step [180/250], Loss: 0.6546230316162109\n",
      "Epoch [19/30], Step [190/250], Loss: 0.6951708197593689\n",
      "Epoch [19/30], Step [200/250], Loss: 0.7692861557006836\n",
      "Epoch [19/30], Step [210/250], Loss: 0.6987319588661194\n",
      "Epoch [19/30], Step [220/250], Loss: 0.7281709909439087\n",
      "Epoch [19/30], Step [230/250], Loss: 0.6981549263000488\n",
      "Epoch [19/30], Step [240/250], Loss: 0.6613932847976685\n",
      "Epoch [19/30], Step [250/250], Loss: 0.6515719890594482\n",
      "Epoch [20/30], Step [10/250], Loss: 0.683667004108429\n",
      "Epoch [20/30], Step [20/250], Loss: 0.6330690383911133\n",
      "Epoch [20/30], Step [30/250], Loss: 0.6992578506469727\n",
      "Epoch [20/30], Step [40/250], Loss: 0.7081963419914246\n",
      "Epoch [20/30], Step [50/250], Loss: 0.7147829532623291\n",
      "Epoch [20/30], Step [60/250], Loss: 0.6547238826751709\n",
      "Epoch [20/30], Step [70/250], Loss: 0.627391517162323\n",
      "Epoch [20/30], Step [80/250], Loss: 0.6972628831863403\n",
      "Epoch [20/30], Step [90/250], Loss: 0.6500757932662964\n",
      "Epoch [20/30], Step [100/250], Loss: 0.7282431125640869\n",
      "Epoch [20/30], Step [110/250], Loss: 0.6599644422531128\n",
      "Epoch [20/30], Step [120/250], Loss: 0.691277265548706\n",
      "Epoch [20/30], Step [130/250], Loss: 0.6712023019790649\n",
      "Epoch [20/30], Step [140/250], Loss: 0.6875613927841187\n",
      "Epoch [20/30], Step [150/250], Loss: 0.6852554082870483\n",
      "Epoch [20/30], Step [160/250], Loss: 0.7059615850448608\n",
      "Epoch [20/30], Step [170/250], Loss: 0.7474350333213806\n",
      "Epoch [20/30], Step [180/250], Loss: 0.6700282096862793\n",
      "Epoch [20/30], Step [190/250], Loss: 0.7267058491706848\n",
      "Epoch [20/30], Step [200/250], Loss: 0.6795942783355713\n",
      "Epoch [20/30], Step [210/250], Loss: 0.7355214953422546\n",
      "Epoch [20/30], Step [220/250], Loss: 0.7097989320755005\n",
      "Epoch [20/30], Step [230/250], Loss: 0.6741981506347656\n",
      "Epoch [20/30], Step [240/250], Loss: 0.7197920680046082\n",
      "Epoch [20/30], Step [250/250], Loss: 0.6666856408119202\n",
      "Epoch [21/30], Step [10/250], Loss: 0.6850540637969971\n",
      "Epoch [21/30], Step [20/250], Loss: 0.6577891111373901\n",
      "Epoch [21/30], Step [30/250], Loss: 0.7145082354545593\n",
      "Epoch [21/30], Step [40/250], Loss: 0.6782787442207336\n",
      "Epoch [21/30], Step [50/250], Loss: 0.7092875242233276\n",
      "Epoch [21/30], Step [60/250], Loss: 0.6552045941352844\n",
      "Epoch [21/30], Step [70/250], Loss: 0.665422260761261\n",
      "Epoch [21/30], Step [80/250], Loss: 0.7131606340408325\n",
      "Epoch [21/30], Step [90/250], Loss: 0.6851215362548828\n",
      "Epoch [21/30], Step [100/250], Loss: 0.7093809843063354\n",
      "Epoch [21/30], Step [110/250], Loss: 0.6839103698730469\n",
      "Epoch [21/30], Step [120/250], Loss: 0.6863808035850525\n",
      "Epoch [21/30], Step [130/250], Loss: 0.6923962831497192\n",
      "Epoch [21/30], Step [140/250], Loss: 0.7143585085868835\n",
      "Epoch [21/30], Step [150/250], Loss: 0.7165741324424744\n",
      "Epoch [21/30], Step [160/250], Loss: 0.7011140584945679\n",
      "Epoch [21/30], Step [170/250], Loss: 0.7145777344703674\n",
      "Epoch [21/30], Step [180/250], Loss: 0.6781455278396606\n",
      "Epoch [21/30], Step [190/250], Loss: 0.704175591468811\n",
      "Epoch [21/30], Step [200/250], Loss: 0.6643280982971191\n",
      "Epoch [21/30], Step [210/250], Loss: 0.7143128514289856\n",
      "Epoch [21/30], Step [220/250], Loss: 0.7122169137001038\n",
      "Epoch [21/30], Step [230/250], Loss: 0.7329443693161011\n",
      "Epoch [21/30], Step [240/250], Loss: 0.7038950324058533\n",
      "Epoch [21/30], Step [250/250], Loss: 0.683397114276886\n",
      "Epoch [22/30], Step [10/250], Loss: 0.6960069537162781\n",
      "Epoch [22/30], Step [20/250], Loss: 0.6595947742462158\n",
      "Epoch [22/30], Step [30/250], Loss: 0.7287018895149231\n",
      "Epoch [22/30], Step [40/250], Loss: 0.7046036720275879\n",
      "Epoch [22/30], Step [50/250], Loss: 0.7062811255455017\n",
      "Epoch [22/30], Step [60/250], Loss: 0.7442296743392944\n",
      "Epoch [22/30], Step [70/250], Loss: 0.6482053399085999\n",
      "Epoch [22/30], Step [80/250], Loss: 0.722833514213562\n",
      "Epoch [22/30], Step [90/250], Loss: 0.6747336387634277\n",
      "Epoch [22/30], Step [100/250], Loss: 0.7139792442321777\n",
      "Epoch [22/30], Step [110/250], Loss: 0.680081844329834\n",
      "Epoch [22/30], Step [120/250], Loss: 0.686549186706543\n",
      "Epoch [22/30], Step [130/250], Loss: 0.6854720115661621\n",
      "Epoch [22/30], Step [140/250], Loss: 0.6525530815124512\n",
      "Epoch [22/30], Step [150/250], Loss: 0.6676555871963501\n",
      "Epoch [22/30], Step [160/250], Loss: 0.7014628052711487\n",
      "Epoch [22/30], Step [170/250], Loss: 0.7186480760574341\n",
      "Epoch [22/30], Step [180/250], Loss: 0.6748342514038086\n",
      "Epoch [22/30], Step [190/250], Loss: 0.7034397125244141\n",
      "Epoch [22/30], Step [200/250], Loss: 0.6637327075004578\n",
      "Epoch [22/30], Step [210/250], Loss: 0.6852638125419617\n",
      "Epoch [22/30], Step [220/250], Loss: 0.6631066203117371\n",
      "Epoch [22/30], Step [230/250], Loss: 0.7248471975326538\n",
      "Epoch [22/30], Step [240/250], Loss: 0.7282781004905701\n",
      "Epoch [22/30], Step [250/250], Loss: 0.678613007068634\n",
      "Epoch [23/30], Step [10/250], Loss: 0.6844161748886108\n",
      "Epoch [23/30], Step [20/250], Loss: 0.6881325244903564\n",
      "Epoch [23/30], Step [30/250], Loss: 0.6631232500076294\n",
      "Epoch [23/30], Step [40/250], Loss: 0.7202731370925903\n",
      "Epoch [23/30], Step [50/250], Loss: 0.6977999210357666\n",
      "Epoch [23/30], Step [60/250], Loss: 0.7103397846221924\n",
      "Epoch [23/30], Step [70/250], Loss: 0.6726264953613281\n",
      "Epoch [23/30], Step [80/250], Loss: 0.6642501354217529\n",
      "Epoch [23/30], Step [90/250], Loss: 0.7357184886932373\n",
      "Epoch [23/30], Step [100/250], Loss: 0.7160366773605347\n",
      "Epoch [23/30], Step [110/250], Loss: 0.6603021621704102\n",
      "Epoch [23/30], Step [120/250], Loss: 0.6760040521621704\n",
      "Epoch [23/30], Step [130/250], Loss: 0.696141242980957\n",
      "Epoch [23/30], Step [140/250], Loss: 0.6645365357398987\n",
      "Epoch [23/30], Step [150/250], Loss: 0.7011918425559998\n",
      "Epoch [23/30], Step [160/250], Loss: 0.6758050322532654\n",
      "Epoch [23/30], Step [170/250], Loss: 0.6683043837547302\n",
      "Epoch [23/30], Step [180/250], Loss: 0.6827936172485352\n",
      "Epoch [23/30], Step [190/250], Loss: 0.699557900428772\n",
      "Epoch [23/30], Step [200/250], Loss: 0.6873543858528137\n",
      "Epoch [23/30], Step [210/250], Loss: 0.6973046064376831\n",
      "Epoch [23/30], Step [220/250], Loss: 0.6847941279411316\n",
      "Epoch [23/30], Step [230/250], Loss: 0.686026930809021\n",
      "Epoch [23/30], Step [240/250], Loss: 0.712138831615448\n",
      "Epoch [23/30], Step [250/250], Loss: 0.6938803791999817\n",
      "Epoch [24/30], Step [10/250], Loss: 0.6833834648132324\n",
      "Epoch [24/30], Step [20/250], Loss: 0.7029370069503784\n",
      "Epoch [24/30], Step [30/250], Loss: 0.6896952390670776\n",
      "Epoch [24/30], Step [40/250], Loss: 0.6966062784194946\n",
      "Epoch [24/30], Step [50/250], Loss: 0.6755800247192383\n",
      "Epoch [24/30], Step [60/250], Loss: 0.6890952587127686\n",
      "Epoch [24/30], Step [70/250], Loss: 0.6705589294433594\n",
      "Epoch [24/30], Step [80/250], Loss: 0.7066176533699036\n",
      "Epoch [24/30], Step [90/250], Loss: 0.758873701095581\n",
      "Epoch [24/30], Step [100/250], Loss: 0.699566125869751\n",
      "Epoch [24/30], Step [110/250], Loss: 0.7008506059646606\n",
      "Epoch [24/30], Step [120/250], Loss: 0.686880350112915\n",
      "Epoch [24/30], Step [130/250], Loss: 0.6831185817718506\n",
      "Epoch [24/30], Step [140/250], Loss: 0.6989403963088989\n",
      "Epoch [24/30], Step [150/250], Loss: 0.7022895812988281\n",
      "Epoch [24/30], Step [160/250], Loss: 0.7047298550605774\n",
      "Epoch [24/30], Step [170/250], Loss: 0.6803637742996216\n",
      "Epoch [24/30], Step [180/250], Loss: 0.6698098182678223\n",
      "Epoch [24/30], Step [190/250], Loss: 0.6965357661247253\n",
      "Epoch [24/30], Step [200/250], Loss: 0.7183314561843872\n",
      "Epoch [24/30], Step [210/250], Loss: 0.7083855271339417\n",
      "Epoch [24/30], Step [220/250], Loss: 0.688880205154419\n",
      "Epoch [24/30], Step [230/250], Loss: 0.6859614253044128\n",
      "Epoch [24/30], Step [240/250], Loss: 0.6815621852874756\n",
      "Epoch [24/30], Step [250/250], Loss: 0.7023071050643921\n",
      "Epoch [25/30], Step [10/250], Loss: 0.6979001760482788\n",
      "Epoch [25/30], Step [20/250], Loss: 0.6792093515396118\n",
      "Epoch [25/30], Step [30/250], Loss: 0.7000377178192139\n",
      "Epoch [25/30], Step [40/250], Loss: 0.6891401410102844\n",
      "Epoch [25/30], Step [50/250], Loss: 0.6950706839561462\n",
      "Epoch [25/30], Step [60/250], Loss: 0.6931962966918945\n",
      "Epoch [25/30], Step [70/250], Loss: 0.6918748021125793\n",
      "Epoch [25/30], Step [80/250], Loss: 0.7022840976715088\n",
      "Epoch [25/30], Step [90/250], Loss: 0.7233110666275024\n",
      "Epoch [25/30], Step [100/250], Loss: 0.6882573366165161\n",
      "Epoch [25/30], Step [110/250], Loss: 0.6959525346755981\n",
      "Epoch [25/30], Step [120/250], Loss: 0.6953780651092529\n",
      "Epoch [25/30], Step [130/250], Loss: 0.7029913067817688\n",
      "Epoch [25/30], Step [140/250], Loss: 0.7104859948158264\n",
      "Epoch [25/30], Step [150/250], Loss: 0.6983399391174316\n",
      "Epoch [25/30], Step [160/250], Loss: 0.6920713186264038\n",
      "Epoch [25/30], Step [170/250], Loss: 0.7179511189460754\n",
      "Epoch [25/30], Step [180/250], Loss: 0.6971415281295776\n",
      "Epoch [25/30], Step [190/250], Loss: 0.7037041783332825\n",
      "Epoch [25/30], Step [200/250], Loss: 0.6952695846557617\n",
      "Epoch [25/30], Step [210/250], Loss: 0.7007227540016174\n",
      "Epoch [25/30], Step [220/250], Loss: 0.686070442199707\n",
      "Epoch [25/30], Step [230/250], Loss: 0.692324161529541\n",
      "Epoch [25/30], Step [240/250], Loss: 0.6936407089233398\n",
      "Epoch [25/30], Step [250/250], Loss: 0.6896817088127136\n",
      "Epoch [26/30], Step [10/250], Loss: 0.7085744142532349\n",
      "Epoch [26/30], Step [20/250], Loss: 0.6863793730735779\n",
      "Epoch [26/30], Step [30/250], Loss: 0.6817866563796997\n",
      "Epoch [26/30], Step [40/250], Loss: 0.7037662267684937\n",
      "Epoch [26/30], Step [50/250], Loss: 0.7046667337417603\n",
      "Epoch [26/30], Step [60/250], Loss: 0.6918007135391235\n",
      "Epoch [26/30], Step [70/250], Loss: 0.713044285774231\n",
      "Epoch [26/30], Step [80/250], Loss: 0.6832862496376038\n",
      "Epoch [26/30], Step [90/250], Loss: 0.667504608631134\n",
      "Epoch [26/30], Step [100/250], Loss: 0.6760569214820862\n",
      "Epoch [26/30], Step [110/250], Loss: 0.707482099533081\n",
      "Epoch [26/30], Step [120/250], Loss: 0.6977518200874329\n",
      "Epoch [26/30], Step [130/250], Loss: 0.6955530047416687\n",
      "Epoch [26/30], Step [140/250], Loss: 0.7124805450439453\n",
      "Epoch [26/30], Step [150/250], Loss: 0.6924611330032349\n",
      "Epoch [26/30], Step [160/250], Loss: 0.6965060234069824\n",
      "Epoch [26/30], Step [170/250], Loss: 0.6868378520011902\n",
      "Epoch [26/30], Step [180/250], Loss: 0.7103825807571411\n",
      "Epoch [26/30], Step [190/250], Loss: 0.6711806654930115\n",
      "Epoch [26/30], Step [200/250], Loss: 0.6948347091674805\n",
      "Epoch [26/30], Step [210/250], Loss: 0.7058894634246826\n",
      "Epoch [26/30], Step [220/250], Loss: 0.6947336196899414\n",
      "Epoch [26/30], Step [230/250], Loss: 0.689943253993988\n",
      "Epoch [26/30], Step [240/250], Loss: 0.6956008672714233\n",
      "Epoch [26/30], Step [250/250], Loss: 0.6892440319061279\n",
      "Epoch [27/30], Step [10/250], Loss: 0.6945648193359375\n",
      "Epoch [27/30], Step [20/250], Loss: 0.697243332862854\n",
      "Epoch [27/30], Step [30/250], Loss: 0.6995589137077332\n",
      "Epoch [27/30], Step [40/250], Loss: 0.6961522698402405\n",
      "Epoch [27/30], Step [50/250], Loss: 0.7141368389129639\n",
      "Epoch [27/30], Step [60/250], Loss: 0.6883167028427124\n",
      "Epoch [27/30], Step [70/250], Loss: 0.681597888469696\n",
      "Epoch [27/30], Step [80/250], Loss: 0.6933290362358093\n",
      "Epoch [27/30], Step [90/250], Loss: 0.6990853548049927\n",
      "Epoch [27/30], Step [100/250], Loss: 0.6930828094482422\n",
      "Epoch [27/30], Step [110/250], Loss: 0.6889819502830505\n",
      "Epoch [27/30], Step [120/250], Loss: 0.6966762542724609\n",
      "Epoch [27/30], Step [130/250], Loss: 0.7014245986938477\n",
      "Epoch [27/30], Step [140/250], Loss: 0.7081984281539917\n",
      "Epoch [27/30], Step [150/250], Loss: 0.6894259452819824\n",
      "Epoch [27/30], Step [160/250], Loss: 0.695622444152832\n",
      "Epoch [27/30], Step [170/250], Loss: 0.6961721181869507\n",
      "Epoch [27/30], Step [180/250], Loss: 0.6897941827774048\n",
      "Epoch [27/30], Step [190/250], Loss: 0.6890014410018921\n",
      "Epoch [27/30], Step [200/250], Loss: 0.6775841116905212\n",
      "Epoch [27/30], Step [210/250], Loss: 0.6889995336532593\n",
      "Epoch [27/30], Step [220/250], Loss: 0.6887487769126892\n",
      "Epoch [27/30], Step [230/250], Loss: 0.6713950037956238\n",
      "Epoch [27/30], Step [240/250], Loss: 0.6815714836120605\n",
      "Epoch [27/30], Step [250/250], Loss: 0.6999087333679199\n",
      "Epoch [28/30], Step [10/250], Loss: 0.7005322575569153\n",
      "Epoch [28/30], Step [20/250], Loss: 0.6854400634765625\n",
      "Epoch [28/30], Step [30/250], Loss: 0.7016850113868713\n",
      "Epoch [28/30], Step [40/250], Loss: 0.6971641182899475\n",
      "Epoch [28/30], Step [50/250], Loss: 0.6831482648849487\n",
      "Epoch [28/30], Step [60/250], Loss: 0.6957387924194336\n",
      "Epoch [28/30], Step [70/250], Loss: 0.6991732716560364\n",
      "Epoch [28/30], Step [80/250], Loss: 0.6832884550094604\n",
      "Epoch [28/30], Step [90/250], Loss: 0.6862078309059143\n",
      "Epoch [28/30], Step [100/250], Loss: 0.7001485824584961\n",
      "Epoch [28/30], Step [110/250], Loss: 0.686698317527771\n",
      "Epoch [28/30], Step [120/250], Loss: 0.6935960054397583\n",
      "Epoch [28/30], Step [130/250], Loss: 0.6797569990158081\n",
      "Epoch [28/30], Step [140/250], Loss: 0.6913435459136963\n",
      "Epoch [28/30], Step [150/250], Loss: 0.7099695205688477\n",
      "Epoch [28/30], Step [160/250], Loss: 0.6739814877510071\n",
      "Epoch [28/30], Step [170/250], Loss: 0.691004753112793\n",
      "Epoch [28/30], Step [180/250], Loss: 0.6871265172958374\n",
      "Epoch [28/30], Step [190/250], Loss: 0.6769859790802002\n",
      "Epoch [28/30], Step [200/250], Loss: 0.6753854751586914\n",
      "Epoch [28/30], Step [210/250], Loss: 0.6798712015151978\n",
      "Epoch [28/30], Step [220/250], Loss: 0.6959697008132935\n",
      "Epoch [28/30], Step [230/250], Loss: 0.6912880539894104\n",
      "Epoch [28/30], Step [240/250], Loss: 0.7011526823043823\n",
      "Epoch [28/30], Step [250/250], Loss: 0.6955965757369995\n",
      "Epoch [29/30], Step [10/250], Loss: 0.700312077999115\n",
      "Epoch [29/30], Step [20/250], Loss: 0.688980758190155\n",
      "Epoch [29/30], Step [30/250], Loss: 0.687660813331604\n",
      "Epoch [29/30], Step [40/250], Loss: 0.6973135471343994\n",
      "Epoch [29/30], Step [50/250], Loss: 0.7041200995445251\n",
      "Epoch [29/30], Step [60/250], Loss: 0.6702690720558167\n",
      "Epoch [29/30], Step [70/250], Loss: 0.695311427116394\n",
      "Epoch [29/30], Step [80/250], Loss: 0.7089749574661255\n",
      "Epoch [29/30], Step [90/250], Loss: 0.6968417763710022\n",
      "Epoch [29/30], Step [100/250], Loss: 0.6854453086853027\n",
      "Epoch [29/30], Step [110/250], Loss: 0.6853547096252441\n",
      "Epoch [29/30], Step [120/250], Loss: 0.6865882277488708\n",
      "Epoch [29/30], Step [130/250], Loss: 0.6883337497711182\n",
      "Epoch [29/30], Step [140/250], Loss: 0.705528974533081\n",
      "Epoch [29/30], Step [150/250], Loss: 0.6866053938865662\n",
      "Epoch [29/30], Step [160/250], Loss: 0.6900249123573303\n",
      "Epoch [29/30], Step [170/250], Loss: 0.6984312534332275\n",
      "Epoch [29/30], Step [180/250], Loss: 0.7001223564147949\n",
      "Epoch [29/30], Step [190/250], Loss: 0.6993950605392456\n",
      "Epoch [29/30], Step [200/250], Loss: 0.6955195069313049\n",
      "Epoch [29/30], Step [210/250], Loss: 0.7174205183982849\n",
      "Epoch [29/30], Step [220/250], Loss: 0.6770732998847961\n",
      "Epoch [29/30], Step [230/250], Loss: 0.6760091781616211\n",
      "Epoch [29/30], Step [240/250], Loss: 0.6769121885299683\n",
      "Epoch [29/30], Step [250/250], Loss: 0.7050588130950928\n",
      "Epoch [30/30], Step [10/250], Loss: 0.6745777130126953\n",
      "Epoch [30/30], Step [20/250], Loss: 0.6881678104400635\n",
      "Epoch [30/30], Step [30/250], Loss: 0.6794246435165405\n",
      "Epoch [30/30], Step [40/250], Loss: 0.7122002840042114\n",
      "Epoch [30/30], Step [50/250], Loss: 0.698681116104126\n",
      "Epoch [30/30], Step [60/250], Loss: 0.7196323871612549\n",
      "Epoch [30/30], Step [70/250], Loss: 0.6916103363037109\n",
      "Epoch [30/30], Step [80/250], Loss: 0.6879148483276367\n",
      "Epoch [30/30], Step [90/250], Loss: 0.7075177431106567\n",
      "Epoch [30/30], Step [100/250], Loss: 0.6686447858810425\n",
      "Epoch [30/30], Step [110/250], Loss: 0.7030155062675476\n",
      "Epoch [30/30], Step [120/250], Loss: 0.7014066576957703\n",
      "Epoch [30/30], Step [130/250], Loss: 0.7121413946151733\n",
      "Epoch [30/30], Step [140/250], Loss: 0.6912719011306763\n",
      "Epoch [30/30], Step [150/250], Loss: 0.6733638048171997\n",
      "Epoch [30/30], Step [160/250], Loss: 0.7193289399147034\n",
      "Epoch [30/30], Step [170/250], Loss: 0.6880522966384888\n",
      "Epoch [30/30], Step [180/250], Loss: 0.7069193720817566\n",
      "Epoch [30/30], Step [190/250], Loss: 0.6976951360702515\n",
      "Epoch [30/30], Step [200/250], Loss: 0.6925494074821472\n",
      "Epoch [30/30], Step [210/250], Loss: 0.6907849907875061\n",
      "Epoch [30/30], Step [220/250], Loss: 0.6824172735214233\n",
      "Epoch [30/30], Step [230/250], Loss: 0.6865588426589966\n",
      "Epoch [30/30], Step [240/250], Loss: 0.6921617984771729\n",
      "Epoch [30/30], Step [250/250], Loss: 0.6736024618148804\n",
      "\n",
      "Training job (trainu90lc57j1vm) succeeded, you can check the logs/metrics/output in  the console:\n",
      "https://pai.console.aliyun.com/?regionId=cn-hangzhou&workspaceId=58670#/training/jobs/trainu90lc57j1vm\n"
     ]
    }
   ],
   "source": [
    "from pai.estimator import Estimator\n",
    "from pai.image import retrieve\n",
    "\n",
    "\n",
    "# Total number of epochs over the training data (including those completed by the previous job)\n",
    "epochs = 30\n",
    "\n",
    "resume_est = Estimator(\n",
    "    command=\"python train.py --epochs {}\".format(epochs),\n",
    "    source_dir=\"./train_src/\",\n",
    "    image_uri=retrieve(\"PyTorch\", \"latest\").image_uri,\n",
    "    instance_type=\"ecs.c6.large\",\n",
    "    # Reuse the previous job's checkpoints: the corresponding OSS Bucket path is mounted at /ml/output/checkpoints\n",
    "    checkpoints_path=est.checkpoints_data(),\n",
    "    base_job_name=\"torch_resume_checkpoint\",\n",
    ")\n",
    "\n",
    "resume_est.fit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training job logs show that the job loaded the checkpoint saved by the previous training job and, instead of starting over, continued training from epoch 11."
   ]
  },
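  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The resume job above depends on the training code finding the previous checkpoint under `/ml/output/checkpoints` and continuing from the epoch after the last completed one. The block below is a minimal, framework-free sketch of that save/resume pattern, not the `train.py` used in this example. The file naming scheme `checkpoint_<epoch>.pkl`, the helper names `save_checkpoint`, `latest_checkpoint` and `resume`, and the use of `pickle` in place of `torch.save`/`torch.load` are all assumptions made for illustration. It also assumes each checkpoint records the last *completed* epoch, so a job that saved epoch 10 resumes at epoch 11, which matches the behavior described above.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import pickle\n",
    "import tempfile\n",
    "\n",
    "# Mount point described in this notebook; pass another directory when testing locally.\n",
    "CHECKPOINT_DIR = '/ml/output/checkpoints'\n",
    "# Hypothetical naming scheme: checkpoint_<epoch>.pkl\n",
    "PREFIX, SUFFIX = 'checkpoint_', '.pkl'\n",
    "\n",
    "\n",
    "def save_checkpoint(state, epoch, ckpt_dir=CHECKPOINT_DIR):\n",
    "    # Save the state after `epoch` (1-based) has finished; pickle stands in for torch.save.\n",
    "    os.makedirs(ckpt_dir, exist_ok=True)\n",
    "    path = os.path.join(ckpt_dir, f'{PREFIX}{epoch}{SUFFIX}')\n",
    "    tmp_path = path + '.tmp'\n",
    "    with open(tmp_path, 'wb') as f:\n",
    "        pickle.dump({'epoch': epoch, 'state': state}, f)\n",
    "    # Give the file its final name only after the write has completed.\n",
    "    os.replace(tmp_path, path)\n",
    "    return path\n",
    "\n",
    "\n",
    "def latest_checkpoint(ckpt_dir=CHECKPOINT_DIR):\n",
    "    # Return the path of the highest-epoch checkpoint, or None if none exists.\n",
    "    if not os.path.isdir(ckpt_dir):\n",
    "        return None\n",
    "    best_epoch, best_path = -1, None\n",
    "    for name in os.listdir(ckpt_dir):\n",
    "        if not (name.startswith(PREFIX) and name.endswith(SUFFIX)):\n",
    "            continue  # skip unrelated files and leftover *.tmp files\n",
    "        digits = name[len(PREFIX):-len(SUFFIX)]\n",
    "        # Compare numerically so that epoch 10 wins over epoch 9.\n",
    "        if digits.isdigit() and int(digits) > best_epoch:\n",
    "            best_epoch, best_path = int(digits), os.path.join(ckpt_dir, name)\n",
    "    return best_path\n",
    "\n",
    "\n",
    "def resume(ckpt_dir=CHECKPOINT_DIR):\n",
    "    # Return (start_epoch, state); (1, None) means a fresh start.\n",
    "    path = latest_checkpoint(ckpt_dir)\n",
    "    if path is None:\n",
    "        return 1, None\n",
    "    with open(path, 'rb') as f:\n",
    "        ckpt = pickle.load(f)\n",
    "    return ckpt['epoch'] + 1, ckpt['state']\n",
    "\n",
    "\n",
    "# Usage: simulate a first job that completed 10 epochs, then a resumed job.\n",
    "with tempfile.TemporaryDirectory() as d:\n",
    "    for epoch in range(1, 11):\n",
    "        save_checkpoint({'step': epoch * 250}, epoch, d)\n",
    "    start_epoch, state = resume(d)\n",
    "    print(start_epoch, state)  # prints: 11 {'step': 2500}\n",
    "```\n",
    "\n",
    "Writing to a temporary file and renaming it afterwards means an interrupted save does not leave a half-written file under a checkpoint name. On a local POSIX filesystem `os.replace` is atomic; on an OSS-backed mount the rename semantics may differ, so treat this as a best-effort safeguard rather than a guarantee."
   ]
  },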
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
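    "Using `PyTorch` as an example, this notebook showed how to use `checkpoint`s in PAI training jobs: the training code saves and loads `checkpoint` files through the `/ml/output/checkpoints/` path, and these files are persisted to an OSS Bucket. Training code built on other frameworks, such as `TensorFlow`, `HuggingFace transformers`, or `ModelScope`, can use `checkpoint`s in PAI training jobs in the same way.\n"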
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
